{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda3\\lib\\site-packages\\sklearn\\externals\\joblib\\__init__.py:15: DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n",
      "  warnings.warn(msg, category=DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "from scipy.interpolate import UnivariateSpline\n",
    "from sklearn import linear_model\n",
    "import xgboost as xgb\n",
    "from sklearn.externals import joblib\n",
    "from sklearn.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Widen pandas display limits so wide feature frames are shown without truncation\n",
    "pd.set_option('display.max_rows', 150,\n",
    "              'display.max_columns', 500,\n",
    "              'display.width', 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>link_ID</th>\n",
       "      <th>date</th>\n",
       "      <th>time_interval_begin</th>\n",
       "      <th>travel_time</th>\n",
       "      <th>imputation1</th>\n",
       "      <th>lagging1</th>\n",
       "      <th>lagging2</th>\n",
       "      <th>lagging3</th>\n",
       "      <th>lagging4</th>\n",
       "      <th>lagging5</th>\n",
       "      <th>length</th>\n",
       "      <th>area</th>\n",
       "      <th>vacation</th>\n",
       "      <th>minute_series</th>\n",
       "      <th>day_of_week</th>\n",
       "      <th>day_of_week_en</th>\n",
       "      <th>hour_en</th>\n",
       "      <th>week_hour_1.0,1.0</th>\n",
       "      <th>week_hour_1.0,2.0</th>\n",
       "      <th>week_hour_1.0,3.0</th>\n",
       "      <th>week_hour_2.0,1.0</th>\n",
       "      <th>week_hour_2.0,2.0</th>\n",
       "      <th>week_hour_2.0,3.0</th>\n",
       "      <th>week_hour_3.0,1.0</th>\n",
       "      <th>week_hour_3.0,2.0</th>\n",
       "      <th>week_hour_3.0,3.0</th>\n",
       "      <th>links_num_2</th>\n",
       "      <th>links_num_3</th>\n",
       "      <th>links_num_4</th>\n",
       "      <th>links_num_5</th>\n",
       "      <th>width_3</th>\n",
       "      <th>width_6</th>\n",
       "      <th>width_9</th>\n",
       "      <th>width_12</th>\n",
       "      <th>width_15</th>\n",
       "      <th>link_ID_en</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3377906280028510514</td>\n",
       "      <td>2017-03-01</td>\n",
       "      <td>2017-03-01 06:00:00</td>\n",
       "      <td>1.659311</td>\n",
       "      <td>True</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>48</td>\n",
       "      <td>144</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3377906280028510514</td>\n",
       "      <td>2017-03-01</td>\n",
       "      <td>2017-03-01 06:02:00</td>\n",
       "      <td>1.664941</td>\n",
       "      <td>True</td>\n",
       "      <td>1.659311</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>48</td>\n",
       "      <td>144</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3377906280028510514</td>\n",
       "      <td>2017-03-01</td>\n",
       "      <td>2017-03-01 06:04:00</td>\n",
       "      <td>1.671675</td>\n",
       "      <td>True</td>\n",
       "      <td>1.664941</td>\n",
       "      <td>1.659311</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>48</td>\n",
       "      <td>144</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3377906280028510514</td>\n",
       "      <td>2017-03-01</td>\n",
       "      <td>2017-03-01 06:06:00</td>\n",
       "      <td>1.676886</td>\n",
       "      <td>True</td>\n",
       "      <td>1.671675</td>\n",
       "      <td>1.664941</td>\n",
       "      <td>1.659311</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>48</td>\n",
       "      <td>144</td>\n",
       "      <td>0.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3377906280028510514</td>\n",
       "      <td>2017-03-01</td>\n",
       "      <td>2017-03-01 06:08:00</td>\n",
       "      <td>1.682314</td>\n",
       "      <td>True</td>\n",
       "      <td>1.676886</td>\n",
       "      <td>1.671675</td>\n",
       "      <td>1.664941</td>\n",
       "      <td>1.659311</td>\n",
       "      <td>NaN</td>\n",
       "      <td>48</td>\n",
       "      <td>144</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>47</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               link_ID        date time_interval_begin  travel_time  imputation1  lagging1  lagging2  lagging3  lagging4  lagging5  length  area  vacation  minute_series  day_of_week  day_of_week_en  hour_en  week_hour_1.0,1.0  week_hour_1.0,2.0  week_hour_1.0,3.0  week_hour_2.0,1.0  week_hour_2.0,2.0  week_hour_2.0,3.0  week_hour_3.0,1.0  week_hour_3.0,2.0  week_hour_3.0,3.0  links_num_2  links_num_3  links_num_4  links_num_5  width_3  width_6  width_9  width_12  width_15  link_ID_en\n",
       "0  3377906280028510514  2017-03-01 2017-03-01 06:00:00     1.659311         True       NaN       NaN       NaN       NaN       NaN      48   144       0.0            0.0            3             1.0      1.0                  1                  0                  0                  0                  0                  0                  0                  0                  0            1            0            0            0        1        0        0         0         0          47\n",
       "1  3377906280028510514  2017-03-01 2017-03-01 06:02:00     1.664941         True  1.659311       NaN       NaN       NaN       NaN      48   144       0.0            2.0            3             1.0      1.0                  1                  0                  0                  0                  0                  0                  0                  0                  0            1            0            0            0        1        0        0         0         0          47\n",
       "2  3377906280028510514  2017-03-01 2017-03-01 06:04:00     1.671675         True  1.664941  1.659311       NaN       NaN       NaN      48   144       0.0            4.0            3             1.0      1.0                  1                  0                  0                  0                  0                  0                  0                  0                  0            1            0            0            0        1        0        0         0         0          47\n",
       "3  3377906280028510514  2017-03-01 2017-03-01 06:06:00     1.676886         True  1.671675  1.664941  1.659311       NaN       NaN      48   144       0.0            6.0            3             1.0      1.0                  1                  0                  0                  0                  0                  0                  0                  0                  0            1            0            0            0        1        0        0         0         0          47\n",
       "4  3377906280028510514  2017-03-01 2017-03-01 06:08:00     1.682314         True  1.676886  1.671675  1.664941  1.659311       NaN      48   144       0.0            8.0            3             1.0      1.0                  1                  0                  0                  0                  0                  0                  0                  0                  0            1            0            0            0        1        0        0         0         0          47"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the pre-built feature table (';'-separated); link_ID is read as text, not as a number\n",
    "df = pd.read_csv('data/trainning.txt',\n",
    "                 sep=';',\n",
    "                 dtype={'link_ID': object},\n",
    "                 parse_dates=['time_interval_begin'])\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['lagging5', 'lagging4', 'lagging3', 'lagging2', 'lagging1']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Lag (time-series) feature names, oldest lag first: lagging5, ..., lagging1\n",
    "lagging = 5\n",
    "lagging_feature = ['lagging%01d' % e for e in reversed(range(1, lagging + 1))]\n",
    "lagging_feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns that are never model inputs: identifiers, timestamps, the target travel_time,\n",
    "# and columns deliberately left out of the feature set.\n",
    "excluded_columns = ['time_interval_begin',\n",
    "                    'link_ID', 'link_ID_int',\n",
    "                    'date', 'travel_time',\n",
    "                    # previously misspelled 'imputationl' (letter l), which left imputation1 in the features\n",
    "                    'imputation1', 'minute_series',\n",
    "                    'area', 'hour_en',\n",
    "                    'day_of_week']\n",
    "base_feature = [x for x in df.columns.values.tolist() if x not in excluded_columns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the lag columns here; they are appended back in a fixed order when building train_feature\n",
    "base_feature = list(filter(lambda col: col not in lagging_feature, base_feature))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['imputation1', 'length', 'vacation', 'day_of_week_en', 'week_hour_1.0,1.0', 'week_hour_1.0,2.0', 'week_hour_1.0,3.0', 'week_hour_2.0,1.0', 'week_hour_2.0,2.0', 'week_hour_2.0,3.0', 'week_hour_3.0,1.0', 'week_hour_3.0,2.0', 'week_hour_3.0,3.0', 'links_num_2', 'links_num_3', 'links_num_4', 'links_num_5', 'width_3', 'width_6', 'width_9', 'width_12', 'width_15', 'link_ID_en', 'lagging5', 'lagging4', 'lagging3', 'lagging2', 'lagging1']\n"
     ]
    }
   ],
   "source": [
    "# Model inputs: base features followed by the lag columns (oldest first)\n",
    "train_feature = base_feature + lagging_feature\n",
    "# Validation rows: base features plus the time step and the target; bucket_data and\n",
    "# cross_valid rely on minute_series and travel_time being the last two columns\n",
    "valid_feature = base_feature + ['minute_series', 'travel_time']\n",
    "print(train_feature)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "xgboost训练参数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# XGBoost hyper-parameter search space (currently one candidate value per parameter)\n",
    "params_grid = dict(\n",
    "    learning_rate=[0.05],\n",
    "    n_estimators=[100],\n",
    "    subsample=[0.6],\n",
    "    colsample_bytree=[0.6],\n",
    "    max_depth=[7],\n",
    "    min_child_weight=[1],\n",
    "    reg_alpha=[2],\n",
    "    gamma=[0],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "# Enumerate every hyper-parameter combination in params_grid\n",
    "grid = ParameterGrid(param_grid=params_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bucket_data(lines):\n",
    "    \"\"\"Group validation rows by their time step.\n",
    "\n",
    "    Each row is expected to end with [..., minute_series, travel_time]. Returns a dict\n",
    "    mapping each minute_series value (in first-seen order) to the list of rows having\n",
    "    that value, with the minute_series column removed so every row ends in travel_time.\n",
    "    \"\"\"\n",
    "    grouped = {}\n",
    "    for row in lines:\n",
    "        step = row[-2]\n",
    "        grouped.setdefault(step, []).append(np.delete(row, -2, axis=0))\n",
    "    return grouped\n",
    "\n",
    "\n",
    "def cross_valid(regressor, bucket, lagging):\n",
    "    \"\"\"Recursive multi-step validation over the buckets produced by bucket_data.\n",
    "\n",
    "    Time steps below 120 are ignored. The first `lagging` steps from 120 (120, 122, ...)\n",
    "    seed a sliding window with their true travel_time; every other step >= 120 is\n",
    "    predicted from the row's base features plus that window, and the prediction (not the\n",
    "    truth) is pushed into the window. Returns the mean over steps of the per-step MAPE,\n",
    "    computed after applying expm1 to targets and predictions.\n",
    "\n",
    "    NOTE(review): rows are aligned across steps purely by position, so every bucket is\n",
    "    assumed to list the same (link, date) pairs in the same order -- confirm upstream.\n",
    "    \"\"\"\n",
    "    seed_steps = range(120, 120 + lagging * 2, 2)\n",
    "    n_rows = len(bucket[list(bucket.keys())[0]])\n",
    "    window = [[] for _ in range(n_rows)]\n",
    "    step_losses = []\n",
    "    steps = [s for s in sorted(bucket.keys(), key=float) if s >= 120]\n",
    "    for step in steps:\n",
    "        rows = np.array(bucket[step], dtype=float)\n",
    "        if int(step) in seed_steps:\n",
    "            # Warm-up: feed the observed travel time into the lag window\n",
    "            window = np.concatenate((window, rows[:, -1].reshape(-1, 1)), axis=1)\n",
    "            continue\n",
    "        actual = rows[:, -1]\n",
    "        features = np.concatenate((np.delete(rows, -1, axis=1), window), axis=1)\n",
    "        predicted = regressor.predict(features)\n",
    "        # Slide the window: drop the oldest lag, append this step's prediction\n",
    "        window = np.concatenate((np.delete(window, 0, axis=1), predicted.reshape(-1, 1)), axis=1)\n",
    "        step_losses.append(np.mean(np.abs(np.expm1(actual) - np.expm1(predicted)) / np.expm1(actual)))\n",
    "    return np.mean(step_losses)\n",
    "\n",
    "\n",
    "def mape_ln(y, d):\n",
    "    \"\"\"Custom eval metric: mean absolute percentage error after expm1.\n",
    "\n",
    "    `y` holds the predictions and `d.get_label()` the targets; both are passed through\n",
    "    expm1 first (presumably they are log1p travel times). Returns ('mape', value), the\n",
    "    pair used as eval_metric in xgboost_submit.\n",
    "    \"\"\"\n",
    "    labels = d.get_label()\n",
    "    truth = np.expm1(labels)\n",
    "    relative_errors = np.abs((np.expm1(y) - truth) / truth)\n",
    "    return 'mape', np.sum(relative_errors) / len(labels)\n",
    "\n",
    "\n",
    "def submission(train_feature, regressor,df, file1,file2,file3,file4):\n",
    "    \"\"\"Recursively forecast 30 two-minute steps for July 2017 and write them to four files.\n",
    "\n",
    "    Starts from the rows at 07:58, 14:58 and 17:58 of July 2017. The lag columns are\n",
    "    shifted so lagging1 holds the observed travel_time. Then 30 predictions are made one\n",
    "    step at a time, each fed back in as the new lagging1. Each step appends rows\n",
    "    link_ID;date;[start,end);expm1(prediction) (no header, no index) to:\n",
    "    steps 0-6 -> file1, 7-13 -> file2, 14-21 -> file3, 22-29 -> file4.\n",
    "    All four files are truncated first.\n",
    "    \"\"\"\n",
    "    begin = df['time_interval_begin']\n",
    "    test_df = df.loc[(begin.dt.year == 2017) & (begin.dt.month == 7)\n",
    "                     & begin.dt.hour.isin([7, 14, 17])\n",
    "                     & (begin.dt.minute == 58)].copy()\n",
    "\n",
    "    lag_columns = ['lagging5', 'lagging4', 'lagging3', 'lagging2', 'lagging1']\n",
    "\n",
    "    def push_lag(newest):\n",
    "        # lagging5 <- lagging4 <- lagging3 <- lagging2 <- lagging1 <- newest\n",
    "        for older, newer in zip(lag_columns, lag_columns[1:]):\n",
    "            test_df[older] = test_df[newer]\n",
    "        test_df['lagging1'] = newest\n",
    "\n",
    "    push_lag(test_df['travel_time'])\n",
    "\n",
    "    for path in (file1, file2, file3, file4):\n",
    "        with open(path, 'w'):  # truncate any previous output\n",
    "            pass\n",
    "\n",
    "    for i in range(30):\n",
    "        y_prediction = regressor.predict(test_df[train_feature].values)\n",
    "        push_lag(y_prediction)\n",
    "\n",
    "        test_df['prediction'] = np.expm1(y_prediction)\n",
    "        test_df['time_interval_begin'] = test_df['time_interval_begin'] + pd.DateOffset(minutes=2)\n",
    "        test_df['time_interval'] = test_df['time_interval_begin'].map(\n",
    "            lambda x: '[' + str(x) + ',' + str(x + pd.DateOffset(minutes=2)) + ')')\n",
    "        test_df['time_interval'] = test_df['time_interval'].astype(object)\n",
    "\n",
    "        if i < 7:\n",
    "            out_path = file1\n",
    "        elif i < 14:\n",
    "            out_path = file2\n",
    "        elif i < 22:\n",
    "            # previously file1 here, which left file3 empty and put 15 steps in file1\n",
    "            out_path = file3\n",
    "        else:\n",
    "            out_path = file4\n",
    "        test_df[['link_ID', 'date', 'time_interval', 'prediction']].to_csv(\n",
    "            out_path, mode='a', header=False, index=False, sep=';')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练模块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "def fit_evaluate(df, df_test, params):\n",
    "    \"\"\"Fit an XGBoost regressor on `df` and score it on `df_test` with cross_valid.\n",
    "\n",
    "    Rows of `df` containing any NaN are dropped. 10% of the rest (random split, seed 0)\n",
    "    is held out as the early-stopping set. `df_test` is grouped by minute_series with\n",
    "    bucket_data and scored by recursive multi-step prediction in cross_valid.\n",
    "\n",
    "    Returns (regressor, validation loss, best_iteration, best_score).\n",
    "    \"\"\"\n",
    "    df = df.dropna()\n",
    "    X = df[train_feature].values\n",
    "    y = df['travel_time'].values\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)\n",
    "\n",
    "    valid_data = bucket_data(df_test[valid_feature].values)\n",
    "\n",
    "    regressor = xgb.XGBRegressor(learning_rate=params['learning_rate'],\n",
    "                                 n_estimators=params['n_estimators'],\n",
    "                                 booster='gbtree',\n",
    "                                 # 'reg:linear' is deprecated in favor of this identical\n",
    "                                 # squared-error objective\n",
    "                                 objective='reg:squarederror',\n",
    "                                 n_jobs=-1,\n",
    "                                 subsample=params['subsample'],\n",
    "                                 colsample_bytree=params['colsample_bytree'],\n",
    "                                 random_state=0,\n",
    "                                 max_depth=params['max_depth'],\n",
    "                                 gamma=params['gamma'],\n",
    "                                 min_child_weight=params['min_child_weight'],\n",
    "                                 reg_alpha=params['reg_alpha'])\n",
    "    regressor.fit(X_train, y_train, verbose=False, early_stopping_rounds=10,\n",
    "                  eval_set=[(X_test, y_test)])\n",
    "    loss = cross_valid(regressor, valid_data, lagging=lagging)\n",
    "    return regressor, loss, regressor.best_iteration, regressor.best_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(df, params, best, vis=False):\n",
    "    \"\"\"Five-fold, time-blocked cross-validation of one hyper-parameter set.\n",
    "\n",
    "    The data before 2017-07-01 is split into five consecutive date blocks. Each block in\n",
    "    turn (latest first) is held out and scored by fit_evaluate after training on the\n",
    "    other four. Prints each fold's (best_iteration, best_score, loss) and a summary dict,\n",
    "    and returns the updated best (lowest) mean loss. `params` is not modified.\n",
    "    `vis` is currently unused.\n",
    "    \"\"\"\n",
    "    t = df['time_interval_begin']\n",
    "    cuts = [pd.to_datetime(d) for d in ('2017-03-24', '2017-04-18', '2017-05-12', '2017-06-06')]\n",
    "    # The last block runs up to (not including) 2017-07-01, the same cut-off as\n",
    "    # xgboost_submit. The former '<= 2017-06-30' bound dropped all of June 30 from every fold.\n",
    "    end = pd.to_datetime('2017-07-01')\n",
    "    masks = [t <= cuts[0]]\n",
    "    masks += [(t > lo) & (t <= hi) for lo, hi in zip(cuts, cuts[1:])]\n",
    "    masks.append((t > cuts[-1]) & (t < end))\n",
    "    blocks = [df.loc[m] for m in masks]\n",
    "\n",
    "    losses, iterations, scores = [], [], []\n",
    "    for held_out in reversed(range(len(blocks))):\n",
    "        train_part = pd.concat([b for k, b in enumerate(blocks) if k != held_out])\n",
    "        _, loss, best_iteration, best_score = fit_evaluate(train_part, blocks[held_out], params)\n",
    "        print(best_iteration, best_score, loss)\n",
    "        losses.append(loss)\n",
    "        iterations.append(best_iteration)\n",
    "        scores.append(best_score)\n",
    "\n",
    "    # Build the report on a copy so the caller's params dict is left untouched\n",
    "    report = dict(params)\n",
    "    report['loss_std'] = np.std(losses)\n",
    "    report['loss'] = str(losses)\n",
    "    report['mean_loss'] = np.mean(losses)\n",
    "    report['n_estimators'] = str(iterations)\n",
    "    report['best_score'] = str(scores)\n",
    "\n",
    "    print(str(report))\n",
    "    if np.mean(losses) <= best:\n",
    "        best = np.mean(losses)\n",
    "        print('best with:' + str(report))\n",
    "    return best"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[22:07:01] WARNING: C:/Jenkins/workspace/xgboost-win64_release_0.90/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.\n",
      "99 0.231729 0.09787323564628972\n",
      "[22:12:48] WARNING: C:/Jenkins/workspace/xgboost-win64_release_0.90/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.\n",
      "99 0.211948 0.22588986922596394\n",
      "[22:18:32] WARNING: C:/Jenkins/workspace/xgboost-win64_release_0.90/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.\n",
      "99 0.207832 0.269828138777363\n",
      "[22:24:17] WARNING: C:/Jenkins/workspace/xgboost-win64_release_0.90/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.\n",
      "99 0.205743 0.27878690843594917\n",
      "[22:29:46] WARNING: C:/Jenkins/workspace/xgboost-win64_release_0.90/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.\n",
      "99 0.206546 0.2825731100341743\n",
      "{'colsample_bytree': 0.6, 'gamma': 0, 'learning_rate': 0.05, 'max_depth': 7, 'min_child_weight': 1, 'n_estimators': '[99, 99, 99, 99, 99]', 'reg_alpha': 2, 'subsample': 0.6, 'loss_std': 0.06956988861011186, 'loss': '[0.09787323564628972, 0.22588986922596394, 0.269828138777363, 0.27878690843594917, 0.2825731100341743]', 'mean_loss': 0.23099025242394805, 'best_score': '[0.231729, 0.211948, 0.207832, 0.205743, 0.206546]'}\n",
      "best with:{'colsample_bytree': 0.6, 'gamma': 0, 'learning_rate': 0.05, 'max_depth': 7, 'min_child_weight': 1, 'n_estimators': '[99, 99, 99, 99, 99]', 'reg_alpha': 2, 'subsample': 0.6, 'loss_std': 0.06956988861011186, 'loss': '[0.09787323564628972, 0.22588986922596394, 0.269828138777363, 0.27878690843594917, 0.2825731100341743]', 'mean_loss': 0.23099025242394805, 'best_score': '[0.231729, 0.211948, 0.207832, 0.205743, 0.206546]'}\n"
     ]
    }
   ],
   "source": [
    "# Cross-validate every grid point. `best` holds the lowest mean loss seen so far; a grid\n",
    "# point is announced as best only if its mean loss is <= best (starting threshold 1).\n",
    "best = 1\n",
    "for params in grid:\n",
    "    best = train(df, params, best)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters for the final submission model (same values as the grid point above)\n",
    "submit_params = dict(\n",
    "    learning_rate=0.05,\n",
    "    n_estimators=100,\n",
    "    subsample=0.6,\n",
    "    colsample_bytree=0.6,\n",
    "    max_depth=7,\n",
    "    min_child_weight=1,\n",
    "    reg_alpha=2,\n",
    "    gamma=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def xgboost_submit(df, params):\n",
    "    \"\"\"Train the final model on all rows before 2017-07-01, save it, and write the submission.\n",
    "\n",
    "    Rows with any NaN are dropped. 10% of the remainder (random split, seed 0) is held out\n",
    "    for early stopping on the mape_ln metric. The model is dumped to model/xgbr.pkl, and\n",
    "    submission() writes the July forecasts to submission/xgbr1.txt ... xgbr4.txt.\n",
    "    \"\"\"\n",
    "    train_df = df.loc[df['time_interval_begin'] < pd.to_datetime('2017-07-01')].dropna()\n",
    "    X = train_df[train_feature].values\n",
    "    y = train_df['travel_time'].values\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)\n",
    "\n",
    "    regressor = xgb.XGBRegressor(learning_rate=params['learning_rate'],\n",
    "                                 n_estimators=params['n_estimators'],\n",
    "                                 booster='gbtree',\n",
    "                                 # 'reg:linear' is deprecated in favor of this identical\n",
    "                                 # squared-error objective\n",
    "                                 objective='reg:squarederror',\n",
    "                                 n_jobs=-1,\n",
    "                                 subsample=params['subsample'],\n",
    "                                 colsample_bytree=params['colsample_bytree'],\n",
    "                                 random_state=0,\n",
    "                                 max_depth=params['max_depth'],\n",
    "                                 gamma=params['gamma'],\n",
    "                                 min_child_weight=params['min_child_weight'],\n",
    "                                 reg_alpha=params['reg_alpha'])\n",
    "    regressor.fit(X_train, y_train, verbose=True, early_stopping_rounds=10,\n",
    "                  eval_metric=mape_ln, eval_set=[(X_test, y_test)])\n",
    "\n",
    "    # exist_ok replaces the former bare try/except, which also hid unrelated errors\n",
    "    os.makedirs('model/', exist_ok=True)\n",
    "    joblib.dump(regressor, 'model/xgbr.pkl')\n",
    "    print(regressor)\n",
    "\n",
    "    os.makedirs('submission/', exist_ok=True)\n",
    "    # First file is xgbr1 (digit one); it was misspelled 'xgbrl' (letter l)\n",
    "    submission(train_feature, regressor, df,\n",
    "               'submission/xgbr1.txt', 'submission/xgbr2.txt',\n",
    "               'submission/xgbr3.txt', 'submission/xgbr4.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[22:35:34] WARNING: C:/Jenkins/workspace/xgboost-win64_release_0.90/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.\n",
      "[0]\tvalidation_0-rmse:2.02747\tvalidation_0-mape:0.867894\n",
      "Multiple eval metrics have been passed: 'validation_0-mape' will be used for early stopping.\n",
      "\n",
      "Will train until validation_0-mape hasn't improved in 10 rounds.\n",
      "[1]\tvalidation_0-rmse:1.92734\tvalidation_0-mape:0.850712\n",
      "[2]\tvalidation_0-rmse:1.83231\tvalidation_0-mape:0.83309\n",
      "[3]\tvalidation_0-rmse:1.74204\tvalidation_0-mape:0.815116\n",
      "[4]\tvalidation_0-rmse:1.65635\tvalidation_0-mape:0.796798\n",
      "[5]\tvalidation_0-rmse:1.57575\tvalidation_0-mape:0.777818\n",
      "[6]\tvalidation_0-rmse:1.49911\tvalidation_0-mape:0.758668\n",
      "[7]\tvalidation_0-rmse:1.4258\tvalidation_0-mape:0.739614\n",
      "[8]\tvalidation_0-rmse:1.35624\tvalidation_0-mape:0.720407\n",
      "[9]\tvalidation_0-rmse:1.29025\tvalidation_0-mape:0.701088\n",
      "[10]\tvalidation_0-rmse:1.22764\tvalidation_0-mape:0.681732\n",
      "[11]\tvalidation_0-rmse:1.16884\tvalidation_0-mape:0.662053\n",
      "[12]\tvalidation_0-rmse:1.11249\tvalidation_0-mape:0.642755\n",
      "[13]\tvalidation_0-rmse:1.05955\tvalidation_0-mape:0.62329\n",
      "[14]\tvalidation_0-rmse:1.00937\tvalidation_0-mape:0.603984\n",
      "[15]\tvalidation_0-rmse:0.961824\tvalidation_0-mape:0.584883\n",
      "[16]\tvalidation_0-rmse:0.916274\tvalidation_0-mape:0.56632\n",
      "[17]\tvalidation_0-rmse:0.873956\tvalidation_0-mape:0.547599\n",
      "[18]\tvalidation_0-rmse:0.833681\tvalidation_0-mape:0.529301\n",
      "[19]\tvalidation_0-rmse:0.7949\tvalidation_0-mape:0.511751\n",
      "[20]\tvalidation_0-rmse:0.758183\tvalidation_0-mape:0.494576\n",
      "[21]\tvalidation_0-rmse:0.724009\tvalidation_0-mape:0.477586\n",
      "[22]\tvalidation_0-rmse:0.69164\tvalidation_0-mape:0.46109\n",
      "[23]\tvalidation_0-rmse:0.660537\tvalidation_0-mape:0.445265\n",
      "[24]\tvalidation_0-rmse:0.631783\tvalidation_0-mape:0.429777\n",
      "[25]\tvalidation_0-rmse:0.604043\tvalidation_0-mape:0.414922\n",
      "[26]\tvalidation_0-rmse:0.5778\tvalidation_0-mape:0.400545\n",
      "[27]\tvalidation_0-rmse:0.553042\tvalidation_0-mape:0.386658\n",
      "[28]\tvalidation_0-rmse:0.530196\tvalidation_0-mape:0.373128\n",
      "[29]\tvalidation_0-rmse:0.508169\tvalidation_0-mape:0.36026\n",
      "[30]\tvalidation_0-rmse:0.487437\tvalidation_0-mape:0.3479\n",
      "[31]\tvalidation_0-rmse:0.468318\tvalidation_0-mape:0.336023\n",
      "[32]\tvalidation_0-rmse:0.450428\tvalidation_0-mape:0.324693\n",
      "[33]\tvalidation_0-rmse:0.433156\tvalidation_0-mape:0.31389\n",
      "[34]\tvalidation_0-rmse:0.416931\tvalidation_0-mape:0.303565\n",
      "[35]\tvalidation_0-rmse:0.401729\tvalidation_0-mape:0.293732\n",
      "[36]\tvalidation_0-rmse:0.387835\tvalidation_0-mape:0.284394\n",
      "[37]\tvalidation_0-rmse:0.37451\tvalidation_0-mape:0.275516\n",
      "[38]\tvalidation_0-rmse:0.362507\tvalidation_0-mape:0.267262\n",
      "[39]\tvalidation_0-rmse:0.35083\tvalidation_0-mape:0.259271\n",
      "[40]\tvalidation_0-rmse:0.339946\tvalidation_0-mape:0.251728\n",
      "[41]\tvalidation_0-rmse:0.330117\tvalidation_0-mape:0.244688\n",
      "[42]\tvalidation_0-rmse:0.321113\tvalidation_0-mape:0.238073\n",
      "[43]\tvalidation_0-rmse:0.312602\tvalidation_0-mape:0.231857\n",
      "[44]\tvalidation_0-rmse:0.304387\tvalidation_0-mape:0.225904\n",
      "[45]\tvalidation_0-rmse:0.296781\tvalidation_0-mape:0.220307\n",
      "[46]\tvalidation_0-rmse:0.290079\tvalidation_0-mape:0.21522\n",
      "[47]\tvalidation_0-rmse:0.283964\tvalidation_0-mape:0.210527\n",
      "[48]\tvalidation_0-rmse:0.278207\tvalidation_0-mape:0.206073\n",
      "[49]\tvalidation_0-rmse:0.272557\tvalidation_0-mape:0.201743\n",
      "[50]\tvalidation_0-rmse:0.267379\tvalidation_0-mape:0.197712\n",
      "[51]\tvalidation_0-rmse:0.262584\tvalidation_0-mape:0.193932\n",
      "[52]\tvalidation_0-rmse:0.258517\tvalidation_0-mape:0.190628\n",
      "[53]\tvalidation_0-rmse:0.254437\tvalidation_0-mape:0.187326\n",
      "[54]\tvalidation_0-rmse:0.250997\tvalidation_0-mape:0.184496\n",
      "[55]\tvalidation_0-rmse:0.247837\tvalidation_0-mape:0.181864\n",
      "[56]\tvalidation_0-rmse:0.244739\tvalidation_0-mape:0.179262\n",
      "[57]\tvalidation_0-rmse:0.242088\tvalidation_0-mape:0.176975\n",
      "[58]\tvalidation_0-rmse:0.239432\tvalidation_0-mape:0.174694\n",
      "[59]\tvalidation_0-rmse:0.236956\tvalidation_0-mape:0.172544\n",
      "[60]\tvalidation_0-rmse:0.23472\tvalidation_0-mape:0.170567\n",
      "[61]\tvalidation_0-rmse:0.232673\tvalidation_0-mape:0.168749\n",
      "[62]\tvalidation_0-rmse:0.230954\tvalidation_0-mape:0.167218\n",
      "[63]\tvalidation_0-rmse:0.229382\tvalidation_0-mape:0.165806\n",
      "[64]\tvalidation_0-rmse:0.227969\tvalidation_0-mape:0.164547\n",
      "[65]\tvalidation_0-rmse:0.226601\tvalidation_0-mape:0.163268\n",
      "[66]\tvalidation_0-rmse:0.22546\tvalidation_0-mape:0.162225\n",
      "[67]\tvalidation_0-rmse:0.224374\tvalidation_0-mape:0.161249\n",
      "[68]\tvalidation_0-rmse:0.223225\tvalidation_0-mape:0.160179\n",
      "[69]\tvalidation_0-rmse:0.222167\tvalidation_0-mape:0.159196\n",
      "[70]\tvalidation_0-rmse:0.221212\tvalidation_0-mape:0.158295\n",
      "[71]\tvalidation_0-rmse:0.220377\tvalidation_0-mape:0.157483\n",
      "[72]\tvalidation_0-rmse:0.219618\tvalidation_0-mape:0.156731\n",
      "[73]\tvalidation_0-rmse:0.219029\tvalidation_0-mape:0.156166\n",
      "[74]\tvalidation_0-rmse:0.218453\tvalidation_0-mape:0.155643\n",
      "[75]\tvalidation_0-rmse:0.217805\tvalidation_0-mape:0.154995\n",
      "[76]\tvalidation_0-rmse:0.217225\tvalidation_0-mape:0.154441\n",
      "[77]\tvalidation_0-rmse:0.216778\tvalidation_0-mape:0.154038\n",
      "[78]\tvalidation_0-rmse:0.21637\tvalidation_0-mape:0.153691\n",
      "[79]\tvalidation_0-rmse:0.21592\tvalidation_0-mape:0.153251\n",
      "[80]\tvalidation_0-rmse:0.215582\tvalidation_0-mape:0.152955\n",
      "[81]\tvalidation_0-rmse:0.2153\tvalidation_0-mape:0.152717\n",
      "[82]\tvalidation_0-rmse:0.215012\tvalidation_0-mape:0.152489\n",
      "[83]\tvalidation_0-rmse:0.214672\tvalidation_0-mape:0.15214\n",
      "[84]\tvalidation_0-rmse:0.214429\tvalidation_0-mape:0.151959\n",
      "[85]\tvalidation_0-rmse:0.214134\tvalidation_0-mape:0.151676\n",
      "[86]\tvalidation_0-rmse:0.213853\tvalidation_0-mape:0.151404\n",
      "[87]\tvalidation_0-rmse:0.213573\tvalidation_0-mape:0.151144\n",
      "[88]\tvalidation_0-rmse:0.213402\tvalidation_0-mape:0.151023\n",
      "[89]\tvalidation_0-rmse:0.213262\tvalidation_0-mape:0.15091\n",
      "[90]\tvalidation_0-rmse:0.21304\tvalidation_0-mape:0.150693\n",
      "[91]\tvalidation_0-rmse:0.212849\tvalidation_0-mape:0.150517\n",
      "[92]\tvalidation_0-rmse:0.212735\tvalidation_0-mape:0.150436\n",
      "[93]\tvalidation_0-rmse:0.21258\tvalidation_0-mape:0.150284\n",
      "[94]\tvalidation_0-rmse:0.212416\tvalidation_0-mape:0.150137\n",
      "[95]\tvalidation_0-rmse:0.212256\tvalidation_0-mape:0.150004\n",
      "[96]\tvalidation_0-rmse:0.212172\tvalidation_0-mape:0.149958\n",
      "[97]\tvalidation_0-rmse:0.212046\tvalidation_0-mape:0.149827\n",
      "[98]\tvalidation_0-rmse:0.211958\tvalidation_0-mape:0.149779\n",
      "[99]\tvalidation_0-rmse:0.211834\tvalidation_0-mape:0.149654\n",
      "XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
      "             colsample_bynode=1, colsample_bytree=0.6, gamma=0,\n",
      "             importance_type='gain', learning_rate=0.05, max_delta_step=0,\n",
      "             max_depth=7, min_child_weight=1, missing=None, n_estimators=100,\n",
      "             n_jobs=-1, nthread=None, objective='reg:linear', random_state=0,\n",
      "             reg_alpha=2, reg_lambda=1, scale_pos_weight=1, seed=None,\n",
      "             silent=None, subsample=0.6, verbosity=1)\n"
     ]
    }
   ],
   "source": [
    "xgboost_submit(df, submit_params)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
